%% Initialization
% Reset the environment: remove workspace variables, close every figure
% window and clear the command window.
clear;
close('all');
clc;

%% Load data
% Dataset file names, resolved relative to the directory returned by find_root.
% NOTE(review): training and testing currently point at the same file, so the
% "test" results further down are computed on the training data. The
% alternative pair below presumably provides a separate test split; confirm.
% train_filename = "bayes_train.xlsx";
% test_filename  = "bayes_test.xlsx";
train_filename = "iris.xlsx";
test_filename  = "iris.xlsx";

root = find_root();
train_filename = fullfile(root, train_filename);
test_filename  = fullfile(root, test_filename);

% read_data yields: sample count, feature count, feature matrix, label vector.
[sample_num, feature_num, X, Y] = read_data(train_filename);

%% Model configuration
% Naive Bayes distribution passed to fitcnb: 'normal', 'mvmn', 'kernel' or 'mn'.
method = 'normal';

%% Per-class prior distributions
% Intended for the 'normal' (Gaussian) method. draw_prior_all receives
% `method`, presumably so it can handle other choices itself; confirm.
draw_prior_all(X, Y, method, feature_num);

%% Fit on the training set and predict it back (resubstitution)
fit_args = {'DistributionNames', method};
if strcmp(method, 'mvmn')
    % Multivariate multinomial: every predictor is treated as categorical.
    fit_args = [fit_args, {'CategoricalPredictors', "all"}];
end
Mdl = fitcnb(X, Y, fit_args{:});
[label, ~] = resubPredict(Mdl);   % predicted labels for the training rows

%% Training-set confusion matrix and accuracy
% Third argument false marks this as the training-set plot.
draw_data(label, Y, false);

%% Feature importance, option 1: Shapley values (requires R2021b or later)
% explainer = shapley(Mdl, X);     % build the Shapley explainer
% plot(explainer);                % plot per-feature contributions

%% Feature importance, option 2: custom ranking via draw_importance
% Show the top 80% of features, rounded to the nearest whole number.
top_fraction = 0.8;
want_num = round(top_fraction * feature_num);
mat_loss = draw_importance(X, Y, method, feature_num, want_num);

%% Predict on the test set
% predict_test loads test_filename and applies the trained model. Its outputs
% are presumably: true labels, predicted labels, class scores and the test
% feature data (inferred from the variable names; confirm against predict_test).
[groundtruth, pred, Score, test_data] = predict_test(Mdl, test_filename, feature_num);

%% Test-set confusion matrix and accuracy
% Plot the test predictions against the test labels. The training-set
% `label`/`Y` must not be used here, or the "test" figure would merely
% repeat the resubstitution result. The third argument true marks the
% test-set plot.
% NOTE(review): assumes pred/groundtruth share the label format of label/Y.
draw_data(pred, groundtruth, true);

%% Save prediction results
% Write the test data, prediction scores and ground-truth labels to a file.
save_file(test_data, Score, groundtruth, feature_num);
